# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Integration tests for SFT trainer correctness with golden data generated from
`maxtext/MaxText/scratch_code/generate_sft_golden_data.py`.

ATTENTION: This test should only be run on TPU v4-8. The test
may fail on different versions like v5p-8, v6e-8.

Usage:

  pytest tests/integration_tests/sft_trainer_correctness_test.py
"""

import os.path

import jsonlines
import pytest
import subprocess
import sys
import unittest
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from transformers import AutoTokenizer

from MaxText import maxtext_utils
from MaxText import pyconfig
from MaxText.common_types import MODEL_MODE_TRAIN
from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT, MAXTEXT_TEST_ASSETS_ROOT
from MaxText.input_pipeline import _input_pipeline_utils
from MaxText.layers import models
from MaxText.layers import quantizations


def get_golden_data(model_name):
  """Load the first golden record for `model_name`.

  The golden data is produced by
  maxtext/MaxText/scratch_code/generate_sft_golden_data.py and is read from
  `golden_data_sft_{model_name}.jsonl` under `MAXTEXT_TEST_ASSETS_ROOT`. Only
  the first record of the file is used.

  Args:
    model_name: Model name used to select the golden data file.

  Returns:
    The first record of the JSON Lines file, as decoded by `jsonlines`.

  Raises:
    ValueError: If the golden data file contains no records.
  """
  input_golden_data_path = os.path.join(MAXTEXT_TEST_ASSETS_ROOT, f"golden_data_sft_{model_name}.jsonl")
  with jsonlines.open(input_golden_data_path, "r") as reader:
    # Return on the first record; leaving the `with` block still closes the reader.
    for record in reader:
      return record
  # An empty file used to surface as a bare StopIteration that gave no hint
  # about which file was at fault.
  raise ValueError(f"Golden data file is empty: {input_golden_data_path}")


def initialize_config():
  """Build the pyconfig configuration used by the SFT correctness test.

  Starts from `configs/sft.yml` inside the MaxText package and applies the
  keyword overrides listed below.

  Returns:
    The object returned by `pyconfig.initialize`.
  """
  sft_config_file = os.path.join(MAXTEXT_PKG_DIR, "configs", "sft.yml")
  tokenizer_dir = os.path.join(MAXTEXT_ASSETS_ROOT, "llama2-chat-tokenizer")
  overrides = {
      "run_name": "test-sft-trainer-correctness",
      "model_name": "default",
      "tokenizer_path": tokenizer_dir,
      "enable_checkpointing": False,
      "max_target_length": 32,
      "per_device_batch_size": 1,
      "max_prefill_predict_length": 16,
      "dataset_type": "synthetic",
      "dtype": "float32",
      "matmul_precision": "high",
      "logits_dot_in_fp32": True,
  }
  return pyconfig.initialize([sys.argv[0], sft_config_file], **overrides)


def prepare_maxtext_inputs(maxtext_data, config):
  """Tokenize one example and replicate it into a MaxText input batch.

  The example goes through the chat template, the HF tokenizer and
  `SFTPromptMasking`; the resulting token ids are then stacked
  `global_batch_size` times so every row of the batch is identical.

  Args:
    maxtext_data: Raw example holding a "messages" column.
    config: Config providing `tokenizer_path`, `max_target_length`,
      `per_device_batch_size` and `gradient_accumulation_steps`.

  Returns:
    A dict of int32 arrays, each with a leading batch axis:
      "inputs": token ids.
      "inputs_segmentation": 1 where the token id is non-zero, 0 elsewhere.
      "inputs_position": 0..length-1 for every row.
  """
  tokenizer = AutoTokenizer.from_pretrained(
      config.tokenizer_path,
      add_bos_token=False,
      add_eos_token=False,
      model_max_length=config.max_target_length,
  )
  data = _input_pipeline_utils.apply_chat_template(maxtext_data, tokenizer, "messages")
  tokenized_data = _input_pipeline_utils.tokenization(
      data,
      hf_tokenizer=tokenizer,
      truncation=False,
      max_length=config.max_target_length,
      column_names=["messages"],
  )
  masked_inputs = _input_pipeline_utils.SFTPromptMasking(
      text_column_name="messages",
      completion_only=False,
      max_target_length=config.max_target_length,
      unk_id=tokenizer.unk_token_id,
  ).map(tokenized_data)

  global_batch_size = int(jax.device_count() * config.per_device_batch_size * config.gradient_accumulation_steps)

  # Build each per-example array once instead of once per batch row, and derive
  # the mask and positions from the same normalized int32 token array.
  # NOTE(review): assumes the masked "inputs" entry is a 1-D sequence of token ids.
  tokens = np.asarray(masked_inputs["inputs"], dtype=np.int32)
  segmentation = (tokens != 0).astype(np.int32)
  positions = np.arange(tokens.shape[0], dtype=np.int32)

  def replicate(row):
    """Stack `row` along a new leading axis, once per batch element."""
    return jnp.stack([row] * global_batch_size)

  return {
      "inputs": replicate(tokens),
      "inputs_segmentation": replicate(segmentation),
      "inputs_position": replicate(positions),
  }


def setup_maxtext_model(config):
  """Construct the MaxText model and its state for `config`.

  Args:
    config: Config providing `init_weights_seed`, `mesh_axes` and the model
      and quantization settings read by the MaxText helpers.

  Returns:
    A `(model, state, rng)` tuple: the linen model built in train mode, the
    state returned by `maxtext_utils.setup_decode_state`, and the PRNG key
    seeded from `config.init_weights_seed`.
  """
  rng_key = jax.random.PRNGKey(config.init_weights_seed)
  quantization = quantizations.configure_quantization(config)
  device_mesh = Mesh(maxtext_utils.create_device_mesh(config), config.mesh_axes)
  model = models.transformer_as_linen(
      config=config,
      mesh=device_mesh,
      quant=quantization,
      model_mode=MODEL_MODE_TRAIN,
  )
  # Only the first element of the returned pair is needed here.
  model_state, _ = maxtext_utils.setup_decode_state(model, config, rng_key, device_mesh, None)
  return model, model_state, rng_key


def get_maxtext_logits(config, maxtext_data):
  """Run the MaxText model on `maxtext_data` with dropout disabled.

  Args:
    config: Config passed to `setup_maxtext_model`.
    maxtext_data: Dict with "inputs", "inputs_position" and
      "inputs_segmentation" entries.

  Returns:
    The first element of the pair returned by `model.apply` (the logits).
  """
  model, model_state, rng_key = setup_maxtext_model(config)
  params = model_state.params
  tokens = maxtext_data["inputs"]
  positions = maxtext_data["inputs_position"]
  segment_ids = maxtext_data["inputs_segmentation"]
  # The result is unpacked as a pair; the second element is not used.
  logits, _ = model.apply(
      params,
      tokens,
      positions,
      decoder_segment_ids=segment_ids,
      enable_dropout=False,
      rngs=rng_key,
      mutable="intermediates",
  )
  return logits


def get_token_log_probs(logits, inputs):
  """Return the log probability the logits assign to each next token.

  Position `t` of `logits` is scored against token `t + 1` of `inputs`, so the
  result has one fewer entry along the sequence axis than `inputs`.

  Args:
    logits: Array indexed as [batch, sequence, vocab].
    inputs: Integer token ids indexed as [batch, sequence].

  Returns:
    Array indexed as [batch, sequence - 1] holding per-token log probabilities.
  """
  all_log_probs = jax.nn.log_softmax(logits, axis=-1)
  # Drop the final position: there is no following token to score it against.
  shifted_log_probs = all_log_probs[:, :-1, :]
  next_tokens = inputs[:, 1:]
  # Add a trailing axis so the token ids can index the vocab axis, then remove it.
  gather_index = jnp.expand_dims(next_tokens, axis=-1)
  picked = jnp.take_along_axis(shifted_log_probs, gather_index, axis=-1)
  return jnp.squeeze(picked, axis=-1)


class SFTTrainerCorrectnessTest(unittest.TestCase):
  """Checks MaxText per-token log probabilities against stored golden data."""

  @classmethod
  def setUpClass(cls):
    """Select the RNG implementation and download the tokenizer with gsutil.

    Raises:
      ValueError: If the gsutil download exits with a non-zero code.
    """
    jax.config.update("jax_default_prng_impl", "unsafe_rbg")

    # Append the libtpu RNG flag only when it is not already present.
    libtpu_args = os.environ.get("LIBTPU_INIT_ARGS", "")
    if "xla_tpu_spmd_rng_bit_generator_unsafe" not in libtpu_args:
      os.environ["LIBTPU_INIT_ARGS"] = libtpu_args + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"

    download_command = [
        "gsutil",
        "cp",
        "-r",
        "gs://maxtext-dataset/hf/llama2-chat-tokenizer",
        os.path.join(MAXTEXT_ASSETS_ROOT, ""),
    ]
    exit_code = subprocess.call(download_command)
    if exit_code != 0:
      raise ValueError(f"Download tokenizer with gsutil cp failed with exit code: {exit_code}")

  @pytest.mark.skip(reason="Logit output test fragile, failing on jax upgrade to 0.6.2 b/425997645")
  @pytest.mark.integration_test
  @pytest.mark.tpu_only  # ATTENTION: Only run on TPU V4-8
  def test_sft_trainer_correctness(self):
    """Compare tokens, attention mask and token log probs with the golden record."""
    config = initialize_config()
    golden = get_golden_data(config.model_name)
    model_inputs = prepare_maxtext_inputs(golden["data"], config)

    # Every batch row is identical, so only the first row is compared.
    first_row_tokens = model_inputs["inputs"][0].tolist()
    assert first_row_tokens == golden["tokens"]
    first_row_mask = model_inputs["inputs_segmentation"][0].tolist()
    assert first_row_mask == golden["attention_mask"]

    logits = get_maxtext_logits(config, model_inputs)
    actual_log_probs = get_token_log_probs(logits, model_inputs["inputs"])[0]
    expected_log_probs = np.array(golden["token_log_probs"])

    max_diff = np.max(np.abs(np.subtract(actual_log_probs, expected_log_probs)))
    print("Max numerical difference:", max_diff)

    is_close = jnp.allclose(
        actual_log_probs,
        expected_log_probs,
        rtol=1e-5,
        atol=1e-8,
        equal_nan=False,
    )
    assert is_close.all()
